import java.util.*;

public class Main {
    public static void main(String[] args) {
        Scanner sc=new Scanner(System.in);
        int n=sc.nextInt();
        int m=sc.nextInt();
        int a=sc.nextInt();
        int b=sc.nextInt();
        long ret=0;
        for(long x=0;x<=Math.min(n/2,m);x++){
            long y=Math.min(n-2*x,(m-x)/2);
            ret=Math.max(ret,x*a+b*y);
        }
        System.out.println(ret);
    }
}